{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4f0f701f-c552-4ca1-8188-2cdfc1362f6b",
   "metadata": {},
   "source": [
    "# Uni-Mol Molecular Representation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3449ed8-2a57-4e62-9163-e32baf66e828",
   "metadata": {},
   "source": [
    "**License**\n",
    "\n",
    "Copyright (c) DP Technology.\n",
    "\n",
    "This source code is licensed under the MIT license found in the\n",
    "LICENSE file in the root directory of this source tree.\n",
    "\n",
    "**Citations**\n",
    "\n",
    "Please cite the following papers if you use this notebook:\n",
    "\n",
    "- Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. \"[Uni-Mol: A Universal 3D Molecular Representation Learning Framework.](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361)\"\n",
    "ChemRxiv (2022)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d51f850-76cd-4801-bf2e-a4c53221d586",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import lmdb\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import AllChem\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import glob"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89c70ab0-da59-459d-bf1c-ac307e9e7ae5",
   "metadata": {},
   "source": [
    "### Your SMILES list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfa0ce2a-b7aa-4cae-81ba-27b91c0591e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SMILES strings to featurize. The bond-direction backslash in the seventh entry is\n",
    "# written as an escaped backslash so the literal holds exactly one backslash character.\n",
    "smi_list = [\n",
    "    'CC1=C(C(=O)OC2CCCC2)[C@H](c2ccccc2OC(C)C)C2=C(O)CC(C)(C)CC2=[N+]1',\n",
    "    'COc1cccc(-c2nc(C(=O)NC[C@H]3CCCO3)cc3c2[nH]c2ccccc23)c1',\n",
    "    'O=C1c2ccccc2C(=O)c2c1ccc(C(=O)n1nc3c4c(cccc41)C(=O)c1ccccc1-3)c2[N+](=O)[O-]',\n",
    "    'COc1cc(/C=N/c2nonc2NC(C)=O)ccc1OC(C)C',\n",
    "    'CCC[C@@H]1CN(Cc2ccc3nsnc3c2)C[C@H]1NS(C)(=O)=O',\n",
    "    'CCc1nnc(N/C(O)=C/CCOc2ccc(OC)cc2)s1',\n",
    "    'CC(C)(C)SCCN/C=C1\\\\C(=O)NC(=O)N(c2ccc(Br)cc2)C1=O',\n",
    "    'CC(C)(C)c1nc(COc2ccc3c(c2)CCn2c-3cc(OCC3COCCO3)nc2=O)no1',\n",
    "    'N#CCCNS(=O)(=O)c1ccc(/C(O)=N/c2ccccc2Oc2ccccc2Cl)cc1',\n",
    "    'O=C(Nc1ncc(Cl)s1)c1cccc(S(=O)(=O)Nc2ccc(Br)cc2)c1',\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b109d84a-8d59-445b-9997-d1383ee24079",
   "metadata": {},
   "source": [
    "### Generate conformations from SMILES and save to .lmdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea582d7d-8851-4d46-880e-54867737b232",
   "metadata": {},
   "outputs": [],
   "source": [
    "def smi2coords(smi, seed):\n",
    "    \"\"\"Embed one 3D conformer for a SMILES string and pickle it for LMDB storage.\n",
    "\n",
    "    Args:\n",
    "        smi: SMILES string of the molecule.\n",
    "        seed: Random seed passed to RDKit's conformer embedding.\n",
    "\n",
    "    Returns:\n",
    "        Pickled bytes of a dict with keys 'atoms' (element symbols, hydrogens\n",
    "        included), 'coordinates' (one-element list holding the float32 atom\n",
    "        positions) and 'smi' (the input SMILES).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If RDKit cannot parse `smi` or no conformer can be embedded.\n",
    "        AssertionError: If the atom count and coordinate count disagree.\n",
    "    \"\"\"\n",
    "    mol = Chem.MolFromSmiles(smi)\n",
    "    if mol is None:\n",
    "        # MolFromSmiles signals a parse failure by returning None rather than raising.\n",
    "        raise ValueError('RDKit could not parse SMILES: {}'.format(smi))\n",
    "    mol = AllChem.AddHs(mol)\n",
    "    atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
    "    res = AllChem.EmbedMolecule(mol, randomSeed=seed)\n",
    "    if res == -1:\n",
    "        # Embedding with explicit hydrogens failed: retry on the hydrogen-free\n",
    "        # molecule with many more attempts, then add hydrogens with coordinates.\n",
    "        mol = Chem.MolFromSmiles(smi)\n",
    "        res = AllChem.EmbedMolecule(mol, maxAttempts=5000, randomSeed=seed)\n",
    "        if res == -1:\n",
    "            raise ValueError('Failed to generate a 3D conformer for {}'.format(smi))\n",
    "        mol = AllChem.AddHs(mol, addCoords=True)\n",
    "    try:\n",
    "        # Best-effort refinement: keep the embedded geometry if MMFF cannot run.\n",
    "        AllChem.MMFFOptimizeMolecule(mol)\n",
    "    except Exception:\n",
    "        pass\n",
    "    coordinates = mol.GetConformer().GetPositions()\n",
    "    assert len(atoms) == len(coordinates), \"coordinates shape is not align with {}\".format(smi)\n",
    "    coordinate_list = [coordinates.astype(np.float32)]\n",
    "    return pickle.dumps({'atoms': atoms, 'coordinates': coordinate_list, 'smi': smi}, protocol=-1)\n",
    "\n",
    "def write_lmdb(smiles_list, job_name, seed=42, outpath='./results'):\n",
    "    \"\"\"Generate a conformer for every SMILES and store them in '<outpath>/<job_name>.lmdb'.\n",
    "\n",
    "    Records are keyed by their position in `smiles_list` ('0', '1', ...) and hold\n",
    "    the pickled payload returned by `smi2coords`. An existing file with the same\n",
    "    name is removed first.\n",
    "\n",
    "    Args:\n",
    "        smiles_list: Iterable of SMILES strings.\n",
    "        job_name: Base name of the LMDB file (without extension).\n",
    "        seed: Random seed forwarded to `smi2coords`.\n",
    "        outpath: Output directory; created if it does not exist.\n",
    "    \"\"\"\n",
    "    os.makedirs(outpath, exist_ok=True)\n",
    "    output_name = os.path.join(outpath, '{}.lmdb'.format(job_name))\n",
    "    try:\n",
    "        os.remove(output_name)\n",
    "    except FileNotFoundError:\n",
    "        # Nothing to replace on a first run. Other failures (e.g. permissions)\n",
    "        # propagate so records from an old file are never silently kept.\n",
    "        pass\n",
    "    env_new = lmdb.open(\n",
    "        output_name,\n",
    "        subdir=False,\n",
    "        readonly=False,\n",
    "        lock=False,\n",
    "        readahead=False,\n",
    "        meminit=False,\n",
    "        max_readers=1,\n",
    "        map_size=int(100e9),\n",
    "    )\n",
    "    try:\n",
    "        # The transaction commits on normal exit and aborts if an error escapes.\n",
    "        with env_new.begin(write=True) as txn_write:\n",
    "            # Wrapping the list (not the enumerate) lets tqdm show a total.\n",
    "            for i, smiles in enumerate(tqdm(smiles_list)):\n",
    "                inner_output = smi2coords(smiles, seed=seed)\n",
    "                txn_write.put(f\"{i}\".encode(\"ascii\"), inner_output)\n",
    "    finally:\n",
    "        # Always release the environment, even when a molecule fails.\n",
    "        env_new.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dad25a1a-f93e-4fdf-b389-2a3fe61a40ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Run configuration: edit these values for your own job ---\n",
    "seed = 42\n",
    "batch_size = 16\n",
    "only_polar = 0  # no h\n",
    "dict_name = 'dict.txt'\n",
    "\n",
    "job_name = 'get_mol_repr'  # replace with your custom name\n",
    "data_path = './results'  # replace with your data path\n",
    "results_path = data_path  # replace with your save path\n",
    "weight_path = '../ckp/mol_pre_no_h_220816.pt'  # replace with your ckpt path\n",
    "\n",
    "# Write the conformers for smi_list to <data_path>/<job_name>.lmdb.\n",
    "write_lmdb(smi_list, job_name=job_name, seed=seed, outpath=data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12284210-7f86-4062-b291-7c077ef6f83a",
   "metadata": {},
   "source": [
    "### Infer from ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fb2391b-81b0-4b11-95ea-3b7855db9bc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: Inference currently supports a single GPU only, which is why the command below\n",
    "# sets CUDA_VISIBLE_DEVICES=\"0\"; change the index to use a different device.\n",
    "# Copy the dictionary file into the data directory before running inference.\n",
    "!cp ../example_data/molecule/$dict_name $data_path\n",
    "!CUDA_VISIBLE_DEVICES=\"0\" python ../unimol/infer.py --user-dir ../unimol $data_path --valid-subset $job_name \\\n",
    "       --results-path $results_path \\\n",
    "       --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n",
    "       --task unimol --loss unimol_infer --arch unimol_base \\\n",
    "       --path $weight_path \\\n",
    "       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n",
    "       --only-polar $only_polar --dict-name $dict_name \\\n",
    "       --log-interval 50 --log-format simple --random-token-prob 0 --leave-unmasked-prob 1.0 --mode infer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8421258-eca6-4801-aadd-fc67fd928cb1",
   "metadata": {},
   "source": [
    "### Read .pkl and save results to .csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c456f31e-94fc-4593-97c9-1db7182465aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_csv_results(predict_path, results_path):\n",
    "    \"\"\"Flatten batched inference output into one row per molecule and save it as CSV.\n",
    "\n",
    "    Args:\n",
    "        predict_path: Path to a pickled iterable of batch dicts. Each batch is read\n",
    "            through the keys 'bsz', 'data_name', 'mol_repr_cls' and 'pair_repr'.\n",
    "        results_path: Directory that receives 'mol_repr.csv'.\n",
    "    \"\"\"\n",
    "    # NOTE: unpickling can execute arbitrary code; only load files you produced yourself.\n",
    "    predict = pd.read_pickle(predict_path)\n",
    "    smi_list, mol_repr_list, pair_repr_list = [], [], []\n",
    "    for batch in predict:\n",
    "        sz = batch[\"bsz\"]\n",
    "        for i in range(sz):\n",
    "            smi_list.append(batch[\"data_name\"][i])\n",
    "            mol_repr_list.append(batch[\"mol_repr_cls\"][i])\n",
    "            pair_repr_list.append(batch[\"pair_repr\"][i])\n",
    "    predict_df = pd.DataFrame({\"SMILES\": smi_list, \"mol_repr\": mol_repr_list, \"pair_repr\": pair_repr_list})\n",
    "    # DataFrame.info() prints its own summary and returns None, so it is called on its\n",
    "    # own instead of being passed to print(), which used to emit a stray 'None'.\n",
    "    predict_df.info()\n",
    "    print(predict_df.head(1))\n",
    "    # NOTE(review): to_csv stores each representation through its string form; if these\n",
    "    # are large arrays the text may be abbreviated -- confirm before relying on the CSV.\n",
    "    predict_df.to_csv(os.path.join(results_path, 'mol_repr.csv'), index=False)\n",
    "\n",
    "# Locate the prediction pickle for this job (presumably written by the inference step)\n",
    "# and fail with a clear message when none exists instead of a bare IndexError.\n",
    "pkl_matches = sorted(glob.glob(f'{results_path}/*_{job_name}.out.pkl'))\n",
    "if not pkl_matches:\n",
    "    raise FileNotFoundError(\n",
    "        f'No *_{job_name}.out.pkl file found in {results_path}; run the inference step first.'\n",
    "    )\n",
    "pkl_path = pkl_matches[0]  # sorted() makes the choice deterministic if several match\n",
    "get_csv_results(pkl_path, results_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
